#!/usr/bin/env python3
# -*- coding: utf-8 -*-
#
# Copyright (c) 2016-2018, Niklas Hauser
# Copyright (c) 2017, Fabian Greif
# Copyright (c) 2020, Mike Wolfram
# Copyright (c) 2021, Raphael Lehmann
# Copyright (c) 2021-2023, Christopher Durand
#
# This file is part of the modm project.
#
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
# -----------------------------------------------------------------------------

from collections import defaultdict
import re

def _get_signal_names(driver):
    signal_names = set()
    for request_data in driver["requests"]:
        for request in request_data["request"]:
            for signal in request["signal"]:
                if "name" in signal:
                    signal_name = signal["name"].capitalize()
                    signal_names.add(signal_name)
    return sorted(list(signal_names))


def _validate_channel_count(driver):
    channel_count = 8
    assert len(driver["mux-channels"]) == 1  # only one DMAMUX instance is supported
    channels = driver["mux-channels"][0]["mux-channel"]
    instance_channels = defaultdict(list)
    for channel in channels:
        instance_channels[channel.get("dma-instance")].append(channel["dma-channel"])
    for instance in instance_channels:
        channel_list = [int(c) for c in instance_channels[instance]]
        channel_list.sort()
        assert channel_list == list(range(0, channel_count))


def _get_substitutions(device, instance = ""):
    """Assemble the substitution dictionary handed to the BDMA templates.

    The driver is looked up on ``device`` with the query ``"bdma:stm32*"``.

    Returned keys:
        target:        ``device.identifier``
        dma:           the looked-up driver
        dma_signals:   result of ``_get_signal_names`` for that driver
        channel_count: fixed at 8
        instance:      the ``instance`` argument, passed through unchanged
                       (defaults to an empty string)
    """
    bdma_driver = device.get_driver("bdma:stm32*")
    return dict(
        target=device.identifier,
        dma=bdma_driver,
        dma_signals=_get_signal_names(bdma_driver),
        channel_count=8,
        instance=instance,
    )


class Instance(Module):
    """Submodule that generates the BDMA driver files for a single instance.

    One object is created per entry of the driver's ``"instance"`` list by
    the module-level ``prepare()``.
    """

    def __init__(self, instance):
        """Remember the instance identifier this submodule stands for."""
        self.instance = instance

    def init(self, module):
        """Name and describe the submodule after its instance identifier."""
        module.name = str(self.instance)
        module.description = "Instance {}".format(self.instance)

    def prepare(self, module, options):
        """Always report the submodule as available."""
        return True

    def build(self, env):
        """Render the per-instance header and source file.

        The output files are named ``bdma_<instance>.hpp`` and
        ``bdma_<instance>.cpp`` and are placed below
        ``modm/src/modm/platform/dma``.
        """
        device = env[":target"]

        # The driver itself is looked up inside _get_substitutions(); a
        # separate, unused lookup here was dropped.
        env.substitutions = _get_substitutions(device, str(self.instance))
        env.outbasepath = "modm/src/modm/platform/dma"

        env.template("bdma.hpp.in", "bdma_{}.hpp".format(self.instance))
        env.template("bdma.cpp.in", "bdma_{}.cpp".format(self.instance))


def init(module):
    """Set the name and the human-readable description of this module."""
    module.name = ":platform:bdma"
    module.description = "Basic Direct Memory Access Controller (BDMA)"


def prepare(module, options):
    """Decide whether the module is available and register its submodules.

    The dependencies on ``:cmsis:device`` and ``:platform:rcc`` are declared
    first, before any availability check. The module is available only when
    the target's platform is ``stm32`` and the target has a ``bdma`` driver.
    If that driver has an ``"instance"`` entry, one ``Instance`` submodule is
    added per listed instance, converted to ``int``.

    Returns:
        True if the module is available for the target, False otherwise.
    """
    target = options[":target"]
    module.depends(":cmsis:device", ":platform:rcc")

    # Reject anything that is not an STM32 with a BDMA peripheral.
    if target.identifier.platform != "stm32" or not target.has_driver("bdma"):
        return False

    bdma_driver = target.get_driver("bdma")
    if "instance" in bdma_driver:
        for instance_id in listify(bdma_driver["instance"]):
            module.add_submodule(Instance(int(instance_id)))

    return True


def build(env):
    """Generate the files shared by all BDMA instances.

    The driver's channel layout is validated first. ``bdma_hal.hpp`` is then
    copied and ``bdma_base.hpp.in`` rendered below
    ``modm/src/modm/platform/dma``. The ``bdma.hpp.in`` and ``bdma.cpp.in``
    templates are rendered here only when the driver has no ``"instance"``
    entry.
    """
    target = env[":target"]
    bdma_driver = target.get_driver("bdma")

    _validate_channel_count(bdma_driver)

    env.substitutions = _get_substitutions(target)
    env.outbasepath = "modm/src/modm/platform/dma"

    env.copy("bdma_hal.hpp")
    env.template("bdma_base.hpp.in")

    # A driver with an "instance" entry gets no un-suffixed driver files.
    if "instance" in bdma_driver:
        return

    for template_name in ("bdma.hpp.in", "bdma.cpp.in"):
        env.template(template_name)
